import torch
import numpy as np
import scipy.sparse as sp 
import struct
import imageio
import os
import argparse
import sys
import tqdm
import glob
import json
import base64
import re


from PIL import Image, ImageDraw, ImageFont
from openai import OpenAI
from qwen_vl_utils import process_vision_info
from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
from tree import PartTree, save_tree, load_tree
from prompts import system_prompt, root_query_prompt, query_prompt


# Constants
# Model name used when --model_name is not supplied on the command line.
DEFAULT_MODEL = "Qwen/Qwen2.5-VL-7B-Instruct"
# Passed as max_new_tokens to model.generate() in generate_text_from_image.
MAX_OUTPUT_TOKENS = 4096
# NOTE(review): not referenced anywhere in this file; presumably a
# (width, height) cap for input images -- TODO confirm or remove.
MAX_IMAGE_SIZE = (1120, 1120)


def load_model_and_processor(model_name: str, finetuning_path: str = None):
    """Instantiate the Qwen2.5-VL model and its matching processor.

    Args:
        model_name: Name handed to both ``from_pretrained`` calls.
        finetuning_path: Accepted for interface compatibility only; this
            function never reads it, so no adapter weights are applied here.

    Returns:
        A ``(model, processor)`` tuple.
    """
    print(f"Loading model: {model_name}")

    # Keyword arguments for the model load: bfloat16 weights,
    # FlashAttention-2 attention implementation, placed on "cuda".
    load_kwargs = dict(
        torch_dtype=torch.bfloat16,
        attn_implementation="flash_attention_2",
        device_map="cuda",
    )
    vlm = Qwen2_5_VLForConditionalGeneration.from_pretrained(
        model_name, **load_kwargs
    )
    return vlm, AutoProcessor.from_pretrained(model_name)


def process_image(image_path: str = None, image=None) -> Image.Image:
    """Return an RGB image built from an in-memory image or a file path.

    ``image`` takes precedence over ``image_path`` whenever it is not None.

    Args:
        image_path: Path of an image file; only consulted when ``image`` is
            None, and only used if the path exists.
        image: Either a NumPy array (wrapped with ``Image.fromarray`` in
            "RGB" mode) or an object exposing ``convert("RGB")``.

    Returns:
        The image in RGB mode.

    Raises:
        ValueError: If ``image`` is None and ``image_path`` is empty or does
            not exist.
    """
    if image is None:
        # Fall back to the file system only when no in-memory image is given.
        if image_path and os.path.exists(image_path):
            return Image.open(image_path).convert("RGB")
        raise ValueError("No valid image provided")

    if isinstance(image, np.ndarray):
        return Image.fromarray(image, "RGB")
    return image.convert("RGB")

def get_label_score(result: str) -> tuple[str, float]:
    """Extract the label and score from a model reply.

    The reply is expected to contain ``<label>...</label>`` and
    ``<score>...</score>`` tags; the first occurrence of each is used. Tag
    contents may span several lines.

    Args:
        result: Raw text produced by the model.

    Returns:
        ``(label, score)`` where ``label`` is the tag text with surrounding
        whitespace removed and ``score`` is the tag text parsed as a float.

    Raises:
        AssertionError: If either tag is missing. Raised explicitly rather
            than via ``assert`` so the check still runs under ``python -O``.
        ValueError: If the score text is not a valid float.
    """
    # DOTALL lets "." cross newlines, so a tag whose content is wrapped onto
    # its own line(s) is still found.
    label_match = re.search(r'<label>(.*?)</label>', result, re.DOTALL)
    if label_match is None:
        raise AssertionError("Label not found in this output")

    score_match = re.search(r'<score>(.*?)</score>', result, re.DOTALL)
    if score_match is None:
        raise AssertionError("Score not found in this output")

    # float() already ignores surrounding whitespace; the label needs an
    # explicit strip so captions do not carry stray newlines or spaces.
    label_text = label_match.group(1).strip()
    score_value = float(score_match.group(1))

    return label_text, score_value


def generate_text_from_image(
    model, processor, images, conversation
):
    """Run one generation pass over ``conversation`` and return the reply.

    Args:
        model: Object whose ``generate`` method produces output token ids.
        processor: Supplies ``apply_chat_template``, the tokenising call and
            ``batch_decode``.
        images: Not read by this function; the images given to the processor
            are the ones ``process_vision_info`` extracts from
            ``conversation``.
        conversation: List of chat messages to render and answer.

    Returns:
        The decoded text of the first (only) sequence in the batch.
    """
    chat_text = processor.apply_chat_template(
        conversation, add_generation_prompt=True, tokenize=False
    )

    # The video half of the result is not forwarded to the processor.
    vision_images, _vision_videos = process_vision_info([conversation])
    batch = processor(
        text=[chat_text], images=vision_images, padding=True, return_tensors="pt"
    )
    batch = batch.to('cuda')

    sequences = model.generate(**batch, max_new_tokens=MAX_OUTPUT_TOKENS)

    # Drop the first len(prompt) ids of every output row so that only the
    # tokens following the prompt are decoded.
    continuations = []
    for prompt_ids, sequence in zip(batch.input_ids, sequences):
        continuations.append(sequence[len(prompt_ids):])

    decoded = processor.batch_decode(
        continuations, skip_special_tokens=True, clean_up_tokenization_spaces=True
    )
    return decoded[0]


def label_tree(model, processor, tree):
    """Caption every node of ``tree`` depth-first, starting at ``tree.root``.

    Each node is queried with the chat history of its ancestors (their
    queries and the model's replies) still in the conversation; a node's own
    query/reply pair is removed again once all of its children are done.
    A node whose image has width or height 1 is skipped together with its
    whole subtree.

    Args:
        model: Model forwarded to ``generate_text_from_image``.
        processor: Processor forwarded to ``generate_text_from_image``.
        tree: Tree exposing ``root``; nodes expose ``query_image``,
            ``image_path``, ``children`` and ``set_caption``.
    """
    history = [{"role": "system", "content": system_prompt()}]
    images = []

    def visit(node, turns):
        picture = process_image(image=node.query_image)
        # Value unused; the attribute read is kept so access behaviour
        # matches the previous implementation.
        _ = node.image_path

        # Degenerate (one-pixel-wide or -tall) images are not queried.
        width, height = picture.size
        if 1 in (width, height):
            return

        # The root gets its own prompt; every other node shares one.
        prompt_text = root_query_prompt() if node is tree.root else query_prompt()
        turns.append({
            "role": "user",
            "content": [{"type": "text", "text": prompt_text},
                        {"image": picture}],
        })
        images.append(picture)

        reply = generate_text_from_image(model, processor, images, turns)
        label, score = get_label_score(reply)
        node.set_caption(f"{label}: {score}")

        # TODO: verify this
        turns.append({
            "role": "assistant",
            "content": [{"type": "text", "text": reply}],
        })

        for child in node.children:
            visit(child, turns)

        # Undo this node's user query, assistant reply and image so that
        # siblings only see their ancestors' history.
        del turns[-2:]
        images.pop()

    visit(tree.root, history)

def query_tree(tree: PartTree, model, processor):
    """Prepare ``tree`` for querying, then label all of its nodes.

    Calls ``tree.query_preprocess()`` with its default arguments and hands
    the tree to ``label_tree``.
    """
    # Alternative preprocessing that was tried:
    #   tree.query_preprocess(crop=True, pad=True, resize=(224, 224))
    tree.query_preprocess()
    label_tree(model=model, processor=processor, tree=tree)

def main(args):
    """Load the model and tree, label the tree, then write the results.

    Both outputs are written next to the input tree file: a rendering named
    ``tree_with_labels`` and a saved tree named ``tree_labeled.pkl``.

    Args:
        args: Parsed command-line arguments providing ``model_name`` and
            ``tree_path``.
    """
    model, processor = load_model_and_processor(args.model_name)
    tree = load_tree(args.tree_path)
    query_tree(tree, model, processor)

    # Outputs live in the same directory as the input tree.
    output_dir = os.path.dirname(args.tree_path)
    tree.render_tree(os.path.join(output_dir, 'tree_with_labels'))
    save_tree(tree, os.path.join(output_dir, 'tree_labeled.pkl'))

if __name__ == "__main__":
    # Command-line entry point: parse arguments, clear the CUDA cache, run main().
    parser = argparse.ArgumentParser(
        description="Label the nodes of a part hierarchy tree with a Qwen2.5-VL model"
    )
    # Required: main() passes this to load_tree() and os.path.dirname(), so a
    # missing value should be rejected by argparse up front instead of
    # failing after the model has already been loaded.
    parser.add_argument(
        "--tree_path",
        type=str,
        required=True,
        help="Path to part hierarchy tree",
    )
    parser.add_argument(
        "--model_name", type=str, default=DEFAULT_MODEL, help="Model name"
    )

    args = parser.parse_args()
    torch.cuda.empty_cache()
    main(args)

